'''
Author: SlytherinGe
LastEditTime: 2021-03-02 20:00:47
'''
import cv2 
import numpy as np



def tiff_16bit_img_read_and_normalize(img, gamma=None, equalized=True,
                                      low_cut=0.001, high_cut=0.999):
    '''
    Read a single-band (typically 16-bit) image and stretch it into an
    8-bit, 3-channel grey image.

    Processing steps:
    1. Clip the grey levels to the bounds given by the ``low_cut`` /
       ``high_cut`` fractions of the cumulative histogram.
    2. Rescale linearly to ``[0, 1]`` and optionally apply gamma correction.
    3. Quantize to ``uint8`` by truncation, then optionally
       histogram-equalize.

    Args:
        img: path to the image file (read with ``cv2.IMREAD_LOAD_GDAL``),
            or an already-loaded 2-D integer ``numpy.ndarray``. An array
            passed in is not modified.
        gamma: exponent applied to the normalized ``[0, 1]`` image, or
            ``None`` to skip gamma correction.
        equalized: if true, apply ``cv2.equalizeHist`` to the 8-bit result.
        low_cut: fraction of the total pixel count used to choose the
            lower clipping level.
        high_cut: fraction of the total pixel count used to choose the
            upper clipping level. Must satisfy
            ``0 <= low_cut <= high_cut <= 1``.

    Returns:
        ``uint8`` array of shape ``(H, W, 3)``; the grey image is repeated
        in all three channels.

    Raises:
        OSError: if ``img`` is a path that OpenCV cannot read.
        ValueError: if the cut fractions are invalid or the image is not 2-D.
        TypeError: if the image does not have an integer dtype.
    '''
    if not 0.0 <= low_cut <= high_cut <= 1.0:
        raise ValueError(
            f"expected 0 <= low_cut <= high_cut <= 1, got {low_cut}, {high_cut}")

    if isinstance(img, np.ndarray):
        aimg = img
    else:
        aimg = cv2.imread(img, cv2.IMREAD_LOAD_GDAL)
        if aimg is None:
            # cv2.imread reports failure by returning None rather than raising.
            raise OSError(f"cannot read image {img!r}")
    aimg = np.asarray(aimg)
    if aimg.ndim != 2:
        raise ValueError(f"expected a 2-D single-band image, got shape {aimg.shape}")
    if not np.issubdtype(aimg.dtype, np.integer):
        raise TypeError(f"expected an integer image, got dtype {aimg.dtype}")

    # Convert to Python ints so `im_max - im_min + 1` cannot wrap around in
    # uint16 arithmetic when the image spans the full 0..65535 range.
    im_min = int(aimg.min())
    im_max = int(aimg.max())

    # Histogram with one bin per grey level in [im_min, im_max].
    offsets = aimg.astype(np.int64).ravel() - im_min
    hist = np.bincount(offsets, minlength=im_max - im_min + 1)
    cum_hist = np.cumsum(hist)
    total_pxl = cum_hist[-1]

    # Choose each clipping level as the first grey level whose cumulative
    # count exceeds the cut count, clamped to the last level. The fallback
    # is therefore the image's own min/max, never absolute 0.
    last = len(cum_hist) - 1
    lo_idx = min(int(np.searchsorted(cum_hist, total_pxl * low_cut, side='right')), last)
    hi_idx = min(int(np.searchsorted(cum_hist, total_pxl * high_cut, side='right')), last)
    cut_min = im_min + lo_idx
    cut_max = im_min + hi_idx

    # Work on a float copy so the caller's array is never clipped in place.
    clipped = np.clip(aimg.astype(np.float64), cut_min, cut_max)
    span = cut_max - cut_min
    if span > 0:
        aimg_norm = (clipped - cut_min) / span
    else:
        # Degenerate histogram (e.g. a constant image): avoid 0/0 -> NaN.
        aimg_norm = np.zeros_like(clipped)

    if gamma is not None:
        aimg_norm = np.power(aimg_norm, gamma)

    # Compress to 8 bit (truncating, as np.uint8() on floats does).
    aimg8 = (aimg_norm * 255).astype(np.uint8)
    if equalized:
        aimg8 = cv2.equalizeHist(aimg8)
    return np.repeat(aimg8[:, :, np.newaxis], 3, axis=2)

def _gamma_correct_8bit(plane, gamma):
    '''
    Gamma-correct an 8-bit plane.

    The plane is scaled to [0, 1], raised to ``gamma``, scaled back to
    [0, 255] and truncated to ``uint8``.
    '''
    return (np.power(plane / 255.0, gamma) * 255).astype(np.uint8)


def tiff_16bit_img_read_and_colorize(img, high_bits, low_bits, is_channel_wise_equalized=False, gamma_high=None, gamma_low=None):
    '''
    Convert a 16-bit single-band image into a 3-channel image by writing its
    high byte and its low byte into two separate channels.

    Channel indices follow OpenCV's BGR order: 2: red, 1: green, 0: blue.
    The channel chosen by neither argument stays zero. To display the result
    with matplotlib (which expects RGB), reverse the last axis first.

    Args:
        img: path to the image file (read with ``cv2.IMREAD_LOAD_GDAL``),
            or an already-loaded integer ``numpy.ndarray`` (not modified).
        high_bits: channel index (0, 1 or 2) receiving bits 8-15.
        low_bits: channel index (0, 1 or 2) receiving bits 0-7; must differ
            from ``high_bits``.
        is_channel_wise_equalized: if true, apply ``cv2.equalizeHist`` to
            each byte plane independently.
        gamma_high: optional gamma exponent for the high-byte plane.
        gamma_low: optional gamma exponent for the low-byte plane.
            Each gamma is applied to the plane normalized to [0, 1], before
            equalization; ``None`` skips it.

    Returns:
        float64 array of shape ``(H, W, 3)`` with values in [0, 255].

    Raises:
        ValueError: if a channel index is not 0, 1 or 2, or both are equal
            (the low byte would silently overwrite the high byte).
        OSError: if ``img`` is a path that OpenCV cannot read.
    '''
    if high_bits not in (0, 1, 2) or low_bits not in (0, 1, 2):
        raise ValueError(
            f"channel indices must be 0, 1 or 2, got {high_bits}, {low_bits}")
    if high_bits == low_bits:
        raise ValueError(f"high_bits and low_bits must differ, both are {high_bits}")

    if isinstance(img, np.ndarray):
        aim = img
    else:
        aim = cv2.imread(img, cv2.IMREAD_LOAD_GDAL)
        if aim is None:
            # cv2.imread reports failure by returning None rather than raising.
            raise OSError(f"cannot read image {img!r}")
    aim = np.asarray(aim)

    # Split every pixel into its most- and least-significant byte.
    img_hi = np.uint8(np.right_shift(aim, 8))
    img_lo = np.uint8(np.bitwise_and(aim, 255))

    if gamma_high is not None:
        img_hi = _gamma_correct_8bit(img_hi, gamma_high)
    if gamma_low is not None:
        img_lo = _gamma_correct_8bit(img_lo, gamma_low)

    if is_channel_wise_equalized:
        img_hi = cv2.equalizeHist(img_hi)
        img_lo = cv2.equalizeHist(img_lo)

    aimg = np.zeros((aim.shape[0], aim.shape[1], 3))
    aimg[:, :, high_bits] = img_hi
    aimg[:, :, low_bits] = img_lo
    return aimg


# if __name__ == "__main__":
#     img = tiff_16bit_img_read_and_normalize("/media/gejunyao/Disk1/Datasets/AIR-SARship/SARShip/SARShip-1.0-13/SARShip-1.0-13.tiff", 1.2, True)
#     plt.figure()
#     plt.imshow(img[:,:,0], 'gray')
#     plt.show()
#     cv2.imwrite('/media/gejunyao/Disk1/Datasets/AIR-SARship/SARShip/SARShip-1.0-13/SARShip.jpg', img)

if __name__ == "__main__":
    # Demo: show the byte-split colour image next to the contrast-stretched
    # grey image for one SARShip sample.
    import sys
    from pathlib import Path

    import matplotlib.pyplot as plt

    DATA_ROOT = Path("/media/gejunyao/Disk1/Datasets/AIR-SARship/SARShip")
    # Optional first command-line argument selects the image id (default '13').
    IMG_ID = sys.argv[1] if len(sys.argv) > 1 else '13'
    IMG = str(DATA_ROOT / f"SARShip-1.0-{IMG_ID}" / f"SARShip-1.0-{IMG_ID}.tiff")
    img_gray = tiff_16bit_img_read_and_normalize(IMG, 1.2, True)
    img = tiff_16bit_img_read_and_colorize(IMG, 2, 0, True)
    plt.figure()
    plt.subplot(1, 2, 1)
    # The colorized image uses OpenCV's BGR channel order while matplotlib
    # expects RGB, so reverse the channel axis for display.
    plt.imshow(img[:, :, ::-1] / 255)
    plt.subplot(1, 2, 2)
    plt.imshow(img_gray[:, :, 0], cmap='gray')
    plt.show()
    # cv2.imwrite expects BGR, so the colorized image can be saved as-is:
    # cv2.imwrite(str(DATA_ROOT / f"SARShip-1.0-{IMG_ID}" / "SARShip_color.jpg"), img)